#include "pch.h"
//#define GLOBALS			// global variables for image filters are to be declared in this file
 						// (and extern'd elsewhere)
#include "StripeCode.h"
#include "stdio.h"
#include <cstring>
#include <iostream>

using namespace std;

// global objects
MTRNG RNG;				// random number generator, seeded in main() from ARG_randseed
vector<int> DB;			// database of photoIDs (filled by sample_db_query_pair())
int QUERY;				// query photoID (chosen by sample_db_query_pair())

// IR algorithms
ImageFeatures *imgFeatures = NULL;	// prototype feature object: read_dataset() fills it, then clones a fresh one per photo
void (*db_query)(int &, double &);	// ranking routine, chosen in main() from --method
void db_query_random(int &, double &);
void db_query_generic(int &, double&);

// command line parameters
char ARG_input[256];	// input dataset (output file of dataset compiler)
char ARG_algorithm[32]; // algorithm to use
int ARG_ipa = 0;		// number of images per animal
int ARG_quiet = 0;		// non-interactive mode
int ARG_randseed = 31183; // overwritten in main() with time(NULL) % getpid(), unless given with --seed
int ARG_trials = 5;     // random trials
int ARG_sc_avg_cost = 0;    // StripeCode: return average cost over all Stripestrings
double ARG_sc_del_cost = 0.5;	// StripeCode: insert/delete cost (copied to StripeCode::INDELCOST)
int ARG_sc_use_abslen = 0;  // use normalized absolute lengths, not ratios
int ARG_exploratory = 0;    // exploratory mode (only execute function Exploratory())

// Option table consumed by ParseCommandLine(); terminated by the all-NULL entry.
// NOTE(review): the three numeric columns look like (min, max, required) --
// confirm against the CMDLINE_PARAMETERS definition.
// String options pass the destination buffer itself (not its address), so
// --input and --method are declared the same way.
CMDLINE_PARAMETERS cmd[] = {
	{ "--input", CMDLINE_STRING, ARG_input, 1, 255, 1,
	    "dataset file generated by the compiler program" },
	{ "--ipa", CMDLINE_INTEGER, &ARG_ipa, 1, 100, 1,
	    "number of images per animal" },
	{ "--method", CMDLINE_STRING, ARG_algorithm, 1, 31, 1,
	    "algorithm ('random', 'stripecode', 'mrhisto')" },
	{ "--trials", CMDLINE_INTEGER, &ARG_trials, 5, 50000, 0,
	    "number of random trials" },
	{ "--quiet", CMDLINE_TRUE, &ARG_quiet, 0, 1, 0,
		"non-interactive mode" },
	{ "--seed", CMDLINE_INTEGER, &ARG_randseed, 0, 100000, 0,
		"random seed" },
	{ "--sc:avg", CMDLINE_TRUE, &ARG_sc_avg_cost, 0, 1, 0,
		"StripeCode: distance = average cost over StripeStrings (as opposed to min cost)" },
	{ "--sc:INDELCOST", CMDLINE_DOUBLE, &ARG_sc_del_cost, 0, 10, 0,
		"StripeCode: delete cost (usually a real in [0,1])" },
	{ "--explore", CMDLINE_TRUE, &ARG_exploratory, 0, 1, 0,
		"exploratory mode (no experiments)" },
	{ "--sc:abslen", CMDLINE_TRUE, &ARG_sc_use_abslen, 0, 1, 0,
		"StripeCode: use normalized absolute lengths, not ratios" },
	{ "--help", CMDLINE_HELP, NULL, 0, 0, 0, "" },
	{ NULL, 0, NULL, 0, 0, 0, NULL }
};


// ---------------------------------------------------------------------------
// Main data structures (populated by read_sighting_data() / read_dataset())
// ---------------------------------------------------------------------------

// animalID -> every photoID belonging to that animal
map<string, vector<int> >  animal_to_photos;

// photoID -> animalID (reverse of animal_to_photos)
map<int, string>           photo_to_animal;

// all photoIDs kept after read_dataset() filters out small animals
vector<int>                photos;

// all animalIDs kept after read_dataset() filters out small animals
vector<string>             animals;

// photoID -> feature object parsed from the dataset file
map<int, ImageFeatures*>   photo_to_features;

// NOTE(review): intended as keys(animal_to_photos), but nothing in this file
// fills or reads this global (the populated list is 'animals').
vector<string>             animal;

// photoID -> animal name, loaded from SightingData.csv
map<int, string>           photoid_to_animalname;

// General purpose exploratory mode
// -- animal, photographs, stripes/photograph are all loaded
// -- do whatever you want here
void Exploratory() {
    fprintf(stderr, "Exploratory mode.\n");

    vector<int> &plist = animal_to_photos[animals[0]];
    for(int p = 1; p < plist.size(); p++) {
        printf("Plotting comparison of photo %d and %d for animal '%s'.\n", plist[0], plist[p], animals[0].c_str());
        wxString fname;
        fname.Printf(_("comp-%d-%d.png"), plist[0], plist[p]);
        wxImage *img = ((StripeCode*)photo_to_features[plist[0]])->plotComparison(*(StripeCode*)photo_to_features[plist[p]]);
        if(img)
            img->SaveFile(fname, wxBITMAP_TYPE_PNG);
        else {
            fprintf(stderr, "Cannot write jpeg file.\n");
            return;
        }
        delete img;
    }

    // compare non-animals
    vector<int> &plist2 = animal_to_photos[animals[2]];
    for(int p = 0; p < plist2.size(); p++) {
        printf("Plotting NEGATIVE comparison of photo %d ('%s') and photo %d ('%s').\n", plist[0], animals[0].c_str(), plist2[p], animals[2].c_str());
        wxString fname;
        fname.Printf(_("compNEG-%d-%d.png"), plist[0], plist2[p]);
        wxImage *img = ((StripeCode*)photo_to_features[plist[0]])->plotComparison(*(StripeCode*)photo_to_features[plist2[p]]);
        if(img)
            img->SaveFile(fname, wxBITMAP_TYPE_PNG);
        else {
            fprintf(stderr, "Cannot write jpeg file.\n");
            return;
        }
        delete img;
    }

}

// Shuffles 'deck' in place, Fisher-Yates style, using the global RNG.
// Each step draws a position from the still-unshuffled prefix [0, remaining)
// and swaps it into slot remaining-1.
// NOTE(review): assumes RNG.genrand_real2() returns a value in [0, 1).
template<class T>
void shuffle (std::vector<T> & deck) {
    for (int remaining = (int) deck.size(); remaining > 1; ) {
        const int pick = RNG.genrand_real2() * remaining;   // truncates toward zero
        --remaining;
        swap(deck[remaining], deck[pick]);
    }
}

// Loads the photoID -> animal name table (photoid_to_animalname) from the file
// SightingData.csv located in the same directory as the dataset file ARG_input.
// CSV columns used: column 1 = photo id, column 4 = animal name; other columns
// are ignored. Lines starting with '#' and lines with fewer than four columns
// are skipped.
// Returns 1 if the file was read, 0 if it could not be opened.
int read_sighting_data() {
	const int FILENAME_MAXLEN = 255;
	char filename[FILENAME_MAXLEN + 1];
	strncpy(filename, ARG_input, FILENAME_MAXLEN);
	filename[FILENAME_MAXLEN] = '\0';	// strncpy leaves no terminator when it truncates

	// The basename starts just past the last '/' or '\\'. When ARG_input has no
	// directory part it starts at the beginning of the string.
	char *base = filename;
	for (char *q = filename; *q; q++) {
		if (*q == '/' || *q == '\\')
			base = q + 1;
	}
	const size_t room = FILENAME_MAXLEN - (base - filename);
	strncpy(base, "SightingData.csv", room);
	base[room] = '\0';	// == filename[FILENAME_MAXLEN]

	FILE *fp = fopen(filename, "r");
	if (fp == NULL) return 0;

	// read line by line
	const int buf_size = 1024;
	char buf[buf_size];
	while (fgets(buf, buf_size, fp) != NULL) {
		if (buf[0] == '#') continue;
		buf[strcspn(buf, "\r\n")] = '\0';	// drop the line terminator

		// Split off the first four comma-separated fields in place. Scanning
		// stops at the string terminator, so short or blank lines cannot run
		// past the end of the buffer.
		char *field[4];
		int nfields = 0;
		char *cur = buf;
		while (nfields < 4 && cur != NULL) {
			field[nfields++] = cur;
			char *comma = strchr(cur, ',');
			if (comma != NULL) {
				*comma = '\0';
				cur = comma + 1;
			} else {
				cur = NULL;
			}
		}
		if (nfields < 4) continue;	// blank or malformed line

		int photo_id = atoi(field[0]);					// column 1 = photo id
		photoid_to_animalname[photo_id] = field[3];		// column 4 = animal name
	}
	fclose(fp);
	return 1;
}

// reads the input file specified on the command line (in global 'ARG_input')
int read_dataset() {
	// read and parse image features
	FILE *fp = fopen(ARG_input, "r");
	if(!fp)
		return printf("Cannot open '%s'\n", ARG_input);
	const int buf_size = 4096;
	char buf[buf_size+1]; int line = 0;
	int no_animal_name_error_count=0;
	int pic_count=0;
	while(fgets(buf, buf_size, fp) && ++line) {
		if(strncmp(buf, "ANIMAL ", 7)==0) {
            pic_count++;
		    // get animal name and photo id
			char *p;
			for(p=buf+7; *p && p<buf+buf_size; p++) {    // fixed-format matching
				if (*p==' ') {
					*p++ = NULL;
					break;
				}
			}
			int photoid;
			string aname;
			if(*(buf+7) && *p) {
				aname = buf+7;
				photoid = atoi(p);
			} else if (*p==0 && *(buf+7)!=0) {
				//aname = buf+7; // get from SightingData.csv
				photoid = atoi(buf+7);
				if (photoid_to_animalname.empty() && no_animal_name_error_count==0) {
					fprintf(stderr, "WARNING Cannot load animal names from SightingData.csv\n");
				}
				if (photoid_to_animalname.find(photoid)==photoid_to_animalname.end()) {
					if (++no_animal_name_error_count<=0) {
						fprintf(stderr, "WARNING Cannot find animal name. Photo id %d does not exists in SightingData.csv\n", photoid);
					}
					continue;
				}
				aname = photoid_to_animalname.at(photoid);
			} else {
			    fprintf(stderr, "Invalid line in file: line %d\n", line);
			    fprintf(stderr, "buf+7: \"%s\"\n", buf+7);
			    fprintf(stderr, "p: \"%s\"\n", p);
			    fclose(fp);
			    return 0;
			}

			if(photoid < 1) {
				fprintf(stderr, "Error: photoID must be greater than 1 (line %d)\n", line);
				return 0;
			}

            // save photo pointer and read image features
			if(!imgFeatures->read(fp)) {
			    fprintf(stderr, "Error: malformed image feature set (line %d)\n", line);
			    return 0;
			}
			animal_to_photos[aname].push_back(photoid);
			photo_to_features[photoid] = imgFeatures;
			imgFeatures = imgFeatures->clone();
		}
	}
	if (no_animal_name_error_count) {
		fprintf(stderr, "WARNING Cannot find animal name in %d from %d picture(s). ",
				no_animal_name_error_count, pic_count);
	}
	fclose(fp);

    if (animal_to_photos.empty()) {
        fprintf(stderr, "ERROR no animal in the database\n");
        exit(255);
    }

    // delete any animals with less than ipa+1 pictures
    int warning_count=0;
	for(map<string,vector<int> >::iterator itr = animal_to_photos.begin(); itr != animal_to_photos.end(); itr++) {
        if(itr->second.size() < (unsigned)ARG_ipa+1) {
        	warning_count++;
            animal_to_photos.erase(itr);
        }
	}
	if (warning_count>0) {
		fprintf(stderr, "WARNING %d/%lu animal(s) have < %d pictures, ignoring them.\n",
				warning_count, animal_to_photos.size()+warning_count, ARG_ipa+1);
	}


    // BUILD INDICES
	// list of animals = keys of the set of animal names
	for(map<string,vector<int> >::iterator itr = animal_to_photos.begin(); itr!=animal_to_photos.end(); itr++) {
		animals.push_back(itr->first);
		for(vector<int>::iterator it2 = itr->second.begin(); it2 != itr->second.end(); it2++) {
            photos.push_back(*it2);
            photo_to_animal[*it2] = itr->first;
		}
	}

	return 1;
}

// sample a random database of 'dbsize' animals, retaining one random image as the query image
// and 'ARG_ipa' images per animal as the database
//
// in other words, samples uniformly from the space of databases and query images
// INPUT:
//   dbsize  - number of animals
//   ARG_ipa - images per animal (global)
// OUTPUT:
//   QUERY   - query image (global)
//   DB      - database of photo IDs(global)
void sample_db_query_pair(int dbsize) {
	QUERY = -1;
	DB.clear();

	// create database of 'dbsize' animals, 'ARG_ipa' images per animal
	shuffle(animals);
	vector<int> remainder;
	for(int a = 0; a < dbsize; a++) {
		string animalID = animals[a];
		assert(animal_to_photos.find(animalID)!=animal_to_photos.end());
		vector<int> &photolist = animal_to_photos[animalID];
		shuffle(photolist);
		for(int i = 0; i < ARG_ipa; i++)
            DB.push_back(photolist[i]);
		for(int i = ARG_ipa; i < (int)photolist.size(); i++)
			remainder.push_back(photolist[i]);
	}

	// choose query image
	QUERY = remainder[RNG.genrand_real2()*remainder.size()];

    // sanity check to make sure DB does not contain query
    for(unsigned i = 0; i < DB.size(); i++)
        assert(DB[i] != QUERY);
}


void db_query_random(int &correctrank, double &querytime) {
	// a real complicated algorithm
	startClocks();
	shuffle(DB);
	querytime = stopClocks();

	// get the ground truth
	set<string> seen;
	string &correct = photo_to_animal[QUERY];

	// compute rank of correct animal (ranks start at 1)
	for(unsigned i = 0; i < DB.size(); i++) {
		string &a = photo_to_animal[DB[i]];
		seen.insert(a);
		if(a == correct) {
			correctrank = (int) seen.size();
			return;
		}
	}
	assert(0);
}

void db_query_generic(int &correctrank, double &querytime) {
    // rank database images by query image
    ImageFeatures *queryImg = photo_to_features[QUERY];
    multimap<double, int> ranking;
    startClocks();
    for(vector<int>::iterator i = DB.begin(); i != DB.end(); i++)
        ranking.insert(pair<double,int>(photo_to_features[*i]->compare(queryImg, NULL), *i));
    querytime = stopClocks();

    // compute correct rank
    set<string> seen;
    string &ground_truth = photo_to_animal[QUERY];
    for(multimap<double,int>::iterator photo = ranking.begin(); photo != ranking.end(); photo++) {
        string &aname = photo_to_animal[photo->second];
        seen.insert(aname);
        if(aname == ground_truth) {
            correctrank = (int)seen.size();
            return;
        }
    }
    assert(0);
}

int main(int argc, char *argv[]) {
	// setup
	ARG_randseed = time(NULL) % getpid();
	if(!ParseCommandLine(argc, argv, cmd))
		return 1;
	if(!ARG_quiet)
		fprintf(stderr, "StripeCode test and benchmarker. Copyright (c) 2010 Mayank Lahiri (mlahiri@gmail.com).\n\n");
	RNG.init_genrand(ARG_randseed);

	// select the algorithm we're using
	if(!strcmp(ARG_algorithm, "stripecode")) {
		db_query  = db_query_generic;
		imgFeatures= new StripeCode();
		if(ARG_sc_avg_cost)
            StripeCode::RETMINCOST = false;
        StripeCode::INDELCOST = ARG_sc_del_cost;
        StripeCode::USERATIOS = !ARG_sc_use_abslen;
	} else {
        if(!strcmp(ARG_algorithm, "random")) {
            db_query = db_query_random;
            imgFeatures= new StripeCode();
        } else {
            if(!strcmp(ARG_algorithm, "mrhisto")) {
                db_query = db_query_generic;
                imgFeatures= new MultiScaleHistogram();
            } else
                return fprintf(stderr, "Unknown algorithm '%s'\n", ARG_algorithm);
        }
    }

	// read SightingData.csv if the file exists
	if (!read_sighting_data()) {
		return 1;
	}

	// read and parse dataset file (stored in ARG_input)
	if(!read_dataset())
		return 1;

	// print header
	if(!ARG_quiet)
		fprintf(stderr, "# animals %d, photos %d, i.p.a. %d, random %d, algorithm '%s', "
                "sc_mincost %d, sc_indelcost %f, sc_useratios %d.\n",
				(int) animal_to_photos.size(), (int) photo_to_features.size(), 
                ARG_ipa, ARG_randseed, ARG_algorithm, StripeCode::RETMINCOST?1:0, 
                StripeCode::INDELCOST, StripeCode::USERATIOS?1:0);

	if(ARG_exploratory) {
	    wxInitAllImageHandlers();
        Exploratory();
        return 0;
	}

	// print header for R
//	printf("distance same_animal\n");
//	printf("photoid1 animal1 photoid2 animal2 distance same_animal\n");

	// run discrimination test
    const long nc2 = photos.size() * (photos.size()-1) / 2;
    long count = 0;
    time_t t0, t1;
    t0 = time(NULL);

    map<string,double> same_animal_sum, same_animal_num, same_animal_worst;
    map<string,double> diff_animal_sum, diff_animal_num, diff_animal_worst;
    for (int i=0;i<photos.size();i++) {
        const int photoid1 = photos[i];
        const string& animal1 = photo_to_animal[photoid1];
        ImageFeatures *img1 = photo_to_features[photoid1];
        for (int j=i+1;j<photos.size();j++) {
            const int photoid2 = photos[j];
            const string& animal2 = photo_to_animal[photoid2];
            ImageFeatures *img2 = photo_to_features[photoid2];

            double dist = img1->compare(img2, NULL);
            assert(isfinite(dist));
            int same = (animal1 == animal2);

            if (same) {
                same_animal_sum[animal1] += dist;
                same_animal_sum[animal2] += dist;
                same_animal_num[animal1]++;
                same_animal_num[animal2]++;
                if (same_animal_worst[animal1]<dist) {
                    same_animal_worst[animal1] = dist;
                }
                if (same_animal_worst[animal2]<dist) {
                    same_animal_worst[animal2] = dist;
                }
            } else {
                diff_animal_sum[animal1] += dist;
                diff_animal_sum[animal2] += dist;
                diff_animal_num[animal1]++;
                diff_animal_num[animal2]++;
                if (diff_animal_worst.find(animal1)==diff_animal_worst.end() ||
                        diff_animal_worst[animal1]>dist) {
                    diff_animal_worst[animal1] = dist;
                }
                if (diff_animal_worst.find(animal2)==diff_animal_worst.end() ||
                        diff_animal_worst[animal2]>dist) {
                    diff_animal_worst[animal2] = dist;
                }
            }

            //cout << photoid1 << " ";
            //cout << photoid2 << " ";
            //cout << animal1 << " ";
            //cout << animal2 << " ";
            //
            //cout << dist << " " << same << endl;

            count++;
            t1 = time(NULL);
            if (t1-t0>5) {
                cerr << "progress " << (10000*count/nc2)/100. << "%" << endl;
                t0 = t1;
            }
        }
    }

    cout << "animal avg_differece worst_difference" << endl;
    for (int i=0;i<animals.size();i++) {
        string& animal = animals[i];
        if (same_animal_num.find(animal)==same_animal_num.end()) continue;
        if (diff_animal_num.find(animal)==diff_animal_num.end()) continue;
        assert(isfinite(diff_animal_sum[animal]));
        assert(isfinite(same_animal_sum[animal]));
        assert(diff_animal_num[animal]>0);
        assert(same_animal_num[animal]>0);
        double diff_avg = diff_animal_sum[animal]/diff_animal_num[animal];
        double same_avg = same_animal_sum[animal]/same_animal_num[animal];
        assert(isfinite(diff_avg));
        assert(isfinite(same_avg));
        cout << animal << " " << diff_avg - same_avg << " ";
        cout << diff_animal_worst[animal] - same_animal_worst[animal] << endl;
    }
	return 0;
}

